
# HETA_evaluation.py
# A module to implement evaluation metrics used in the HETA paper.

import numpy as np
from scipy.stats import spearmanr
from sklearn.metrics import f1_score

def sensitivity(attribution_scores_list):
    """
    Measure how much token attributions fluctuate across input perturbations.

    Each row of ``attribution_scores_list`` holds the attribution scores from
    one perturbed run, giving a (num_perturbations x num_tokens) layout. The
    standard deviation is taken down each token column (population std,
    ddof=0). The per-token spreads are then averaged into a single number;
    lower values mean the scores vary less between perturbations.

    attribution_scores_list: nested sequence (num_perturbations x num_tokens)
    Returns: float
    """
    per_token_spread = np.array(attribution_scores_list).std(axis=0)
    return float(per_token_spread.mean())

def active_passive_robustness(attr1, attr2):
    """
    Quantify ranking agreement between two attribution vectors.

    The Spearman rank correlation of ``attr1`` and ``attr2`` is returned.
    A value of 1.0 means identical orderings and -1.0 means reversed orderings.
    The p-value from ``scipy.stats.spearmanr`` is discarded.

    attr1, attr2: sequences of attribution scores for the same, aligned tokens
    Returns: float
    """
    rank_corr = spearmanr(attr1, attr2)[0]
    return float(rank_corr)

def f1_alignment(model_top_tokens, human_annotated_tokens):
    """
    Compute the F1 score between model-selected top tokens and human annotations.

    Both inputs are treated as sets, so duplicates are ignored:
      - true positives are tokens chosen by both the model and the annotators;
      - false positives are model-only tokens;
      - false negatives are human-only tokens.
    F1 = 2*TP / (2*TP + FP + FN) then simplifies to
    ``2 * |model & human| / (|model| + |human|)``.

    model_top_tokens: iterable of hashable token ids selected by the model
    human_annotated_tokens: iterable of hashable token ids annotated by humans
    Returns: float in [0, 1]; 0.0 when both selections are empty.
    """
    # Pure set arithmetic: no sorting is required, so token ids of mixed,
    # mutually unorderable types (e.g. ints and strings) are accepted.
    model_set = set(model_top_tokens)
    human_set = set(human_annotated_tokens)

    denominator = len(model_set) + len(human_set)
    if denominator == 0:
        # Neither side selected anything, so F1 is undefined. Report 0.0, the
        # value scikit-learn's f1_score also yields by default in that case.
        return 0.0

    overlap = len(model_set & human_set)
    return 2.0 * overlap / denominator

def aopc(original_probs, masked_probs):
    """
    Compute AOPC (Area Over the Perturbation Curve).

    The probability drop ``original - masked`` is computed element-wise for
    each masking step, and the drops are averaged. Normal numpy broadcasting
    applies, so a single scalar original probability may be paired with a
    sequence of masked probabilities.

    original_probs: probability/probabilities with the original tokens
    masked_probs: probabilities with progressively masked tokens
    Returns: float
    """
    drops = np.asarray(original_probs) - np.asarray(masked_probs)
    return float(drops.mean())

if __name__ == "__main__":
    # Smoke run: evaluate each metric on a tiny hand-made example and print
    # "<label> <value>" for every one of them, in a fixed order.
    _examples = [
        ("Sensitivity:",
         lambda: sensitivity([[0.2, 0.3, 0.5], [0.25, 0.35, 0.45]])),
        ("Active/Passive Robustness:",
         lambda: active_passive_robustness([0.1, 0.5, 0.4], [0.2, 0.4, 0.4])),
        ("F1 Alignment:",
         lambda: f1_alignment([1, 2, 3], [2, 3, 4])),
        ("AOPC:",
         lambda: aopc([0.9, 0.8, 0.7], [0.85, 0.75, 0.65])),
    ]
    for _label, _run in _examples:
        print(_label, _run())
